{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f8b0d8c2f50>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import collections\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "torch.set_printoptions(edgeitems=2)\n",
    "torch.manual_seed(123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_names = ['airplane','automobile','bird','cat','deer',\n",
    "               'dog','frog','horse','ship','truck']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torchvision import datasets, transforms\n",
    "data_path = '../data/cifar/'\n",
    "cifar10 = datasets.CIFAR10(\n",
    "    data_path, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
    "                             (0.2470, 0.2435, 0.2616))\n",
    "    ]))\n",
    "cifar10_val = datasets.CIFAR10(\n",
    "    data_path, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
    "                             (0.2470, 0.2435, 0.2616))\n",
    "    ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_map = {0: 0, 2: 1}\n",
    "class_names = ['airplane', 'bird']\n",
    "cifar2 = [(img, label_map[label])\n",
    "          for img, label in cifar10\n",
    "          if label in [0, 2]]\n",
    "cifar2_val = [(img, label_map[label])\n",
    "              for img, label in cifar10_val\n",
    "              if label in [0, 2]]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3,16,kernel_size=3,padding=1)\n",
    "        self.conv2 = nn.Conv2d(16,8,kernel_size=3,padding=1)\n",
    "        self.fc1 = nn.Linear(8*8*8,32)\n",
    "        self.fc2 = nn.Linear(32,2)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        out = F.max_pool2d(torch.tanh(self.conv1(x)),2)  # 2 是max_pool2d 的参数\n",
    "        out = F.max_pool2d(torch.tanh(self.conv2(out)),2)\n",
    "        out = out.view(-1,8*8*8)\n",
    "        out = torch.tanh(self.fc1(out))\n",
    "        out = self.fc2(out)\n",
    "        \n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "img,_ = cifar2[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1054, 0.1153]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Net()\n",
    "model(img.unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loop\n",
    "import datetime\n",
    "\n",
    "\n",
    "def training_loop(n_epochs,optimizer,model,loss_fn,train_loader):\n",
    "    for epoch in range(1,n_epochs+1):\n",
    "        loss_train = 0.0\n",
    "        for imgs,labels in train_loader:\n",
    "            \n",
    "            outputs = model(imgs)\n",
    "            loss = loss_fn(outputs,labels)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            loss.backward()\n",
    "            \n",
    "            optimizer.step()\n",
    "            \n",
    "            loss_train += loss.item()\n",
    "        if epoch==1 or epoch % 10 == 0:\n",
    "            print('{} Epoch {},Training loss {}'.format(\n",
    "                datetime.datetime.now(),epoch,loss_train/len(train_loader)))  # 计算每个batch的平均损失"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-07-19 17:14:25.352883 Epoch 1,Training loss 0.5701384637386177\n",
      "2020-07-19 17:14:45.797915 Epoch 10,Training loss 0.33377454121401356\n",
      "2020-07-19 17:15:08.995031 Epoch 20,Training loss 0.2964806552905186\n",
      "2020-07-19 17:15:31.422398 Epoch 30,Training loss 0.2676278444801926\n",
      "2020-07-19 17:15:53.826521 Epoch 40,Training loss 0.24656456239094401\n",
      "2020-07-19 17:16:16.482019 Epoch 50,Training loss 0.22621194576951348\n",
      "2020-07-19 17:16:44.447339 Epoch 60,Training loss 0.20755051261490318\n",
      "2020-07-19 17:17:06.357585 Epoch 70,Training loss 0.18867057958131384\n",
      "2020-07-19 17:17:28.880580 Epoch 80,Training loss 0.171582081040759\n",
      "2020-07-19 17:17:51.508608 Epoch 90,Training loss 0.15744791217860143\n",
      "2020-07-19 17:18:16.729977 Epoch 100,Training loss 0.14557443900852446\n"
     ]
    }
   ],
   "source": [
    "train_loader = torch.utils.data.DataLoader(cifar2,batch_size=64,shuffle=True)\n",
    "model = Net()\n",
    "optimizer = optim.SGD(model.parameters(),lr = 1e-2)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "training_loop(n_epochs=100,\n",
    "             optimizer=optimizer,\n",
    "             model=model,\n",
    "             loss_fn=loss_fn,\n",
    "             train_loader=train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy train: 0.93\n",
      "Accuracy val: 0.87\n"
     ]
    }
   ],
   "source": [
    "# 计算精度,经过训练，权重还保存着\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(cifar2,batch_size=64,shuffle=True)\n",
    "val_loader = torch.utils.data.DataLoader(cifar2_val,batch_size=64,shuffle=True)\n",
    "\n",
    "def validate(model,train_loader,val_loader):\n",
    "    for name,loader in[('train',train_loader),('val',val_loader)]:\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            for imgs,labels in loader:\n",
    "                outputs = model(imgs)\n",
    "                _,predicted = torch.max(outputs,dim=1)      # predicted 是索引\n",
    "                total += labels.shape[0]\n",
    "                correct +=int((predicted == labels).sum())\n",
    "              \n",
    "        print(\"Accuracy {}: {:.2f}\".format(name , correct / total))\n",
    "\n",
    "validate(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  保存模型\n",
    "\n",
    "torch.save(model.state_dict(),data_path + 'bird_vs_airplane.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loaded_model = Net()\n",
    "loaded_model.load_state_dict(torch.load(data_path + 'bird_vs_airplane.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training on device cuda.\n"
     ]
    }
   ],
   "source": [
    "device = (torch.device('cuda') if torch.cuda.is_available()\n",
    "else torch.device('cpu'))\n",
    "print(f\"Training on device {device}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "def training_loop(n_epochs,optimizer,model,loss_fn,train_loader):\n",
    "    for epoch in range(1,n_epochs+1):\n",
    "        loss_train = 0.0\n",
    "        for imgs,labels in train_loader:\n",
    "            \n",
    "            imgs = img.to(device = device)          # 输入放到gpu\n",
    "            labels = labels.to(device=device)\n",
    "            \n",
    "            outputs = model(imgs)\n",
    "            \n",
    "            loss = loss_fn(outputs,labels)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            loss_train += loss.item()\n",
    "            \n",
    "        if epoch == 1 or epoch %10 ==0:\n",
    "            print('{} Epoch {},Training loss {}'.format(\n",
    "                datetime.datetime.now(),epoch,loss_train/len(train_loader)))\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-07-19 20:05:11.181773 Epoch 1,Training loss 0.6018951303639989\n",
      "2020-07-19 20:05:18.631052 Epoch 10,Training loss 0.34487766150835975\n",
      "2020-07-19 20:05:26.798797 Epoch 20,Training loss 0.30025195762230333\n",
      "2020-07-19 20:05:34.945215 Epoch 30,Training loss 0.2714038026180996\n",
      "2020-07-19 20:05:43.122595 Epoch 40,Training loss 0.25165232816699207\n",
      "2020-07-19 20:05:51.554141 Epoch 50,Training loss 0.23327995746568508\n",
      "2020-07-19 20:06:00.314958 Epoch 60,Training loss 0.2125224553665538\n",
      "2020-07-19 20:06:08.473854 Epoch 70,Training loss 0.1973991779385099\n",
      "2020-07-19 20:06:16.520341 Epoch 80,Training loss 0.18535497134468357\n",
      "2020-07-19 20:06:21.605332 Epoch 90,Training loss 0.16962906590123086\n",
      "2020-07-19 20:06:24.898841 Epoch 100,Training loss 0.1555779479966042\n"
     ]
    }
   ],
   "source": [
    "train_loader = torch.utils.data.DataLoader(cifar2,batch_size=64,shuffle=True)\n",
    "\n",
    "model = Net().to(device=device)                          # 模型放到gpu（包含所有参数）\n",
    "\n",
    "optimizer = optim.SGD(model.parameters(),lr = 1e-2)\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "training_loop(\n",
    "    n_epochs = 100,\n",
    "    optimizer = optimizer,\n",
    "    model = model,\n",
    "    loss_fn = loss_fn,\n",
    "    train_loader = train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), data_path + 'birds_vs_airplanes_g.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loaded_model = Net().to(device = device)\n",
    "loaded_model.load_state_dict(torch.load(data_path +'birds_vs_airplanes_g.pt',map_location = device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## model design\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'avi'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str = 'v_ApplyEyeMakeup_g01_c01.avi'\n",
    "\n",
    "x = str.split('.')[1]\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
